2021.4.13
本次作业将在MNIST手写数字数据集上实现PCA降维及聚类
⭐ train_X.csv：存放 MNIST 数据集的特征数据(data)；若从网络下载太慢，可直接读取此文件
⭐ train_y.csv：存放 MNIST 数据集的标签数据(target)；若从网络下载太慢，可直接读取此文件
本作业依赖 scikit-learn 库，可通过 conda install scikit-learn 安装
import sklearn
from sklearn.datasets import fetch_openml
import numpy as np
import pandas as pd
# Load MNIST from the pre-exported CSV files rather than downloading it.
# Online alternative (slow):
# mnist=fetch_openml('mnist_784',version=1,cache=True)
N_SAMPLES = 6000  # keep only the first 6000 samples to keep PCA/KMeans fast

X = pd.read_csv('./train_X.csv').to_numpy()[:N_SAMPLES, :]
# Drop the singleton column axis so the labels form a 1-D vector.
y = np.squeeze(pd.read_csv('./train_y.csv').to_numpy()[:N_SAMPLES])
(X.shape, y.shape)
((6000, 784), (6000,))
可以通过调用sklearn库的函数实现,也可以自己实现PCA操作
# PCA on the standardised pixel matrix.
from sklearn import preprocessing
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Rescale every pixel column to zero mean / unit variance before PCA.
X = preprocessing.scale(X)

# No n_components given, so every component is kept; later cells only use the
# leading ones. fit() returns the estimator itself, so it can be chained.
pca = PCA().fit(X)
X_pca = pca.transform(X)
X_pca.shape
(6000, 784)
(提示：将 PCA 得到的主成分，即长度为 784 的特征向量，reshape 回 28×28 的原始图像形状后再显示)
# Show the leading principal components as 28x28 images.
import matplotlib.pyplot as plt
from matplotlib.colors import Colormap

# Each row of pca.components_ is one principal axis, i.e. an eigenvector of
# the data covariance matrix -- not an eigenvalue (those live in
# pca.explained_variance_).
components = pca.components_
# Backward-compatible alias for the old, misleading name.
eigenvalues = components

n_rows, n_cols = 3, 4
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 10))
for i, ax in enumerate(axes.flat):
    # 784 pixels per sample -> back to the original 28x28 image layout.
    ax.imshow(components[i].reshape(28, 28), cmap='jet')
    ax.set_title('Eigenvector ' + str(i + 1))
    ax.set_xticks(())
    ax.set_yticks(())
plt.show()
# Scatter the data on the first two principal components, coloured by the
# true digit. Every 50th occurrence of each digit is labelled with that digit.
from collections import Counter

X_2d = X_pca[:, :2]
fig, ax = plt.subplots()
ax.scatter(X_2d[:, 0], X_2d[:, 1], c=y, cmap='jet')

# Counter returns 0 for labels it has not seen yet, so this works for any
# label set. The old dict only accepted the integer keys 0..9.
y_count = Counter()
for point, label in zip(X_2d, y):
    y_count[label] += 1
    if y_count[label] % 50 == 0:
        ax.annotate(str(label), point)
plt.show()
# Cluster the 2-D PCA projection into 3 groups and compare the result with the
# true digit labels using the Adjusted Rand Index (1.0 = identical partitions,
# values near 0.0 = no better than a random labelling).
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score

# n_init is spelled out because its default changed between scikit-learn
# releases. A fixed random_state makes the clustering, and therefore the ARI
# reported below, reproducible from run to run.
X_2d_K3 = KMeans(n_clusters=3, init='k-means++', n_init=10, random_state=0).fit(X_2d)

fig, ax = plt.subplots()
ax.scatter(X_2d[:, 0], X_2d[:, 1], c=X_2d_K3.labels_, cmap='jet')
plt.show()

adjusted_rand_score(y, X_2d_K3.labels_)
0.06568206509202973
取类别数(即聚类簇数)从 3 至 20(可以自行扩展)，根据 ARI 找到最佳的聚类结果，并绘制聚类结果图
# Sweep the number of clusters and pick the best k by ARI against the true labels.
import math

# k = 3..22 (20 values, filling a 4x5 grid). The task allows extending past 20.
# The names Kmeans_3_20 / ARI_3_20 are kept for compatibility with earlier code.
k_values = list(range(3, 23))
# Same reproducibility settings as the k=3 run: explicit n_init, fixed seed.
Kmeans_3_20 = [
    KMeans(n_clusters=k, init='k-means++', n_init=10, random_state=0).fit(X_2d)
    for k in k_values
]
ARI_3_20 = [adjusted_rand_score(y, km.labels_) for km in Kmeans_3_20]

# The task asks for the best clustering by ARI, so report it explicitly.
best_idx = int(np.argmax(ARI_3_20))
best_k = k_values[best_idx]
print(f"Best number of clusters by ARI: {best_k} (ARI = {ARI_3_20[best_idx]:.3f})")

# Size the grid from the number of k values instead of hard-coding 4x5.
n_cols = 5
n_rows = math.ceil(len(k_values) / n_cols)
plt.figure(figsize=(4 * n_cols, 5 * n_rows))
for i, (k, km, ari) in enumerate(zip(k_values, Kmeans_3_20, ARI_3_20)):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.scatter(X_2d[:, 0], X_2d[:, 1], c=km.labels_, cmap='jet')
    title = "categories: " + str(k) + "\nARI: %.3f" % ari
    if i == best_idx:
        title += " (best)"
    plt.title(title)
plt.show()